# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import field

import logging
import mesop as me

from common.storage import store_to_gcs
from common.utils import create_display_url, https_url_to_gcs_uri
from components.header import header
from components.page_scaffold import page_frame, page_scaffold
from components.stepper import stepper
from models.character_consistency import generate_character_video
from state.state import AppState
from components.library.library_chooser_button import library_chooser_button
from components.library.events import LibrarySelectionChangeEvent
from components.dialog import dialog

logger = logging.getLogger(__name__)


@me.stateclass
class PageState:
    """Per-session UI state for the Character Consistency test page.

    Holds the wizard position, the user's inputs (reference images and
    prompts), the intermediate results reported by the generation workflow,
    and transient UI flags (status line, dialog visibility).
    """

    # 1-based index of the step currently rendered (1 = upload, 2 = select, 3 = video).
    current_step: int = 1
    # Highest step reached so far; the stepper may only jump to steps <= this value.
    max_completed_step: int = 1
    # GCS URIs of the reference images (from uploads or the library chooser).
    uploaded_image_gcs_uris: list[str] = field(default_factory=list) # pylint: disable=invalid-field-call
    # Display URLs for the reference images, appended in lockstep with the list above.
    uploaded_image_display_urls: list[str] = field(default_factory=list) # pylint: disable=invalid-field-call
    # Scene description typed in Step 1.
    scene_prompt: str = ""
    # Prompt used for video generation in Step 3; seeded from scene_prompt on entering Step 3.
    video_prompt: str = ""
    # Character description reported by the generation workflow (shown in Step 2).
    character_description: str = ""
    # Candidate images reported by the workflow, as GCS URIs and as display URLs.
    candidate_image_gcs_uris: list[str] = field(default_factory=list) # pylint: disable=invalid-field-call
    candidate_image_urls: list[str] = field(default_factory=list) # pylint: disable=invalid-field-call
    # The workflow's recommended candidate (GCS URI and display URL).
    best_image_gcs_uri: str = ""
    best_image_url: str = ""
    # Display URL of the candidate the user clicked in Step 2 (not a GCS URI).
    user_selected_image_url: str = "" # This will be a display URL
    # NOTE(review): the outpainted_* fields are not read or written by the code
    # visible in this file — presumably reserved for an outpainting step; confirm.
    outpainted_image_gcs_uri: str = ""
    outpainted_image_display_url: str = ""
    # Final video produced in Step 3 (GCS URI and display URL).
    final_video_gcs_uri: str = ""
    final_video_url: str = ""
    # Status line rendered at the bottom of the page.
    status_message: str = "Ready."
    # Set while a generation handler is running; not read by the UI in this file.
    is_generating: bool = False
    # Whether the "About" dialog is currently shown.
    info_dialog_open: bool = False


# Help text shown in the "About" dialog (rendered as markdown).
_INFO_DIALOG_MARKDOWN = """
This page allows you to test the character consistency workflow step-by-step.

**Step 1: Upload & Prompt**
- Upload one or more reference images of a character.
- Provide a scene prompt to describe the scene you want to place the character in.

**Step 2: Select Image**
- Review the candidate images generated by the model.
- The system's recommended image will be highlighted with a bottom border.
- Select the image you want to use for the video generation.

**Step 3: Create Video**
- The prompt from Step 1 will be used to generate the video.
- You can modify the prompt or use one of the preset buttons to add to the prompt.
- Click "Generate Video" to create the final video.
"""


def _render_info_dialog(state: "PageState"):
    """Render the "About" dialog with the help text and a Close button."""
    with dialog(is_open=state.info_dialog_open):  # pylint: disable=not-context-manager
        me.text("About Character Consistency Test", type="headline-6")
        me.markdown(_INFO_DIALOG_MARKDOWN)
        with me.box(style=me.Style(margin=me.Margin(top=16))):
            me.button("Close", on_click=close_info_dialog, type="flat")


def _render_upload_step(state: "PageState"):
    """Render Step 1: reference image pickers, previews, and the scene prompt."""
    me.text(
        "Step 1: Upload Reference Images and Provide a Scene Prompt",
        style=me.Style(margin=me.Margin(bottom=16)),
    )

    picker_row_style = me.Style(
        display="flex",
        flex_direction="row",
        gap=16,
        margin=me.Margin(bottom=16),
        justify_content="center",
    )
    with me.box(style=picker_row_style):
        me.uploader(
            label="Upload Reference Images",
            on_upload=on_upload,
            multiple=True,
            style=me.Style(width="100%"),
        )
        library_chooser_button(
            on_library_select=on_library_select,
            button_label="Choose from Library",
        )

    # Previews are only shown once at least one reference image is registered.
    if state.uploaded_image_gcs_uris:
        preview_grid_style = me.Style(
            display="flex",
            flex_wrap="wrap",
            gap=10,
            justify_content="center",
            margin=me.Margin(bottom=16),
        )
        with me.box(style=preview_grid_style):
            for display_url in state.uploaded_image_display_urls:
                me.image(
                    src=display_url,
                    style=me.Style(width=200, height=200, object_fit="contain", border_radius=8),
                )

    me.textarea(
        label="Scene Prompt",
        rows=3,
        on_input=on_prompt_input,
        style=me.Style(width="100%", margin=me.Margin(bottom=16)),
    )
    me.button("Generate Alternatives", on_click=generate_alternatives, type="raised")


def _render_candidate_tile(url: str, state: "PageState"):
    """Render one clickable candidate image with its selection highlights."""
    # Outer frame marks the user's pick; inner bottom border marks the system's pick.
    if url == state.user_selected_image_url:
        frame_color = me.theme_var("primary")
    else:
        frame_color = "transparent"
    if url == state.best_image_url:
        underline_color = me.theme_var("secondary")
    else:
        underline_color = "transparent"

    frame_style = me.Style(
        padding=me.Padding.all(4),
        border=me.Border.all(me.BorderSide(width=4, style="solid", color=frame_color)),
        border_radius=12,
        cursor="pointer",
    )
    underline_style = me.Style(
        border=me.Border(
            bottom=me.BorderSide(width=4, style="solid", color=underline_color)
        )
    )

    # The URL doubles as the component key so the click handler can read it back.
    with me.box(key=url, on_click=on_select_image_click, style=frame_style):
        with me.box(style=underline_style):
            me.image(
                src=url,
                style=me.Style(width=200, height=200, border_radius=8),
            )


def _render_selection_step(state: "PageState"):
    """Render Step 2: character description and the candidate image grid."""
    me.text("Step 2: Select the Best Image", style=me.Style(margin=me.Margin(bottom=16)))

    if state.character_description:
        with me.expansion_panel(title="Character Description"):
            me.text(state.character_description)

    if not state.candidate_image_urls:
        return

    grid_style = me.Style(
        display="flex",
        flex_wrap="wrap",
        gap=10,
        justify_content="center",
    )
    with me.box(style=grid_style):
        for candidate_url in state.candidate_image_urls:
            _render_candidate_tile(candidate_url, state)
    me.button("Continue", on_click=next_step, type="raised")


def _render_video_step(state: "PageState"):
    """Render Step 3: video prompt editor, preset buttons, and the result."""
    me.text("Step 3: Create Video", style=me.Style(margin=me.Margin(bottom=16)))

    editor_row_style = me.Style(
        display="flex",
        flex_direction="row",
        gap=16,
        margin=me.Margin(bottom=16),
    )
    with me.box(style=editor_row_style):
        with me.box(style=me.Style(flex_grow=1)):
            me.textarea(
                label="Video Prompt",
                rows=3,
                value=state.video_prompt,
                on_blur=on_video_prompt_blur,
                style=me.Style(width="100%"),
            )
            preset_row_style = me.Style(
                display="flex",
                flex_direction="row",
                gap=16,
                margin=me.Margin(top=16),
            )
            with me.box(style=preset_row_style):
                # Each preset keeps its own literal lambda so every button has a distinct handler.
                me.button("dancing in the rain", on_click=lambda e: on_modify_prompt_click("dancing in the rain"), type="flat")
                me.button("camera zooms out to show the earth", on_click=lambda e: on_modify_prompt_click("camera zooms out to show the earth"), type="flat")
                me.button("person turns magically invisible", on_click=lambda e: on_modify_prompt_click("person turns magically invisible"), type="flat")
        if state.user_selected_image_url:
            me.image(
                src=state.user_selected_image_url,
                style=me.Style(width=200, height=200, border_radius=8),
            )

    me.button("Generate Video", on_click=generate_video, type="raised")
    if state.final_video_url:
        me.video(src=state.final_video_url, style=me.Style(width=600, height=338))


def character_consistency_page_content():
    """Render the Character Consistency test page.

    Draws the optional "About" dialog, the page header and stepper, the body
    of whichever step is current, and the status line at the bottom.
    """
    state = me.state(PageState)

    if state.info_dialog_open:
        _render_info_dialog(state)

    # Maps the 1-based step number to the function that draws that step's body.
    step_renderers = {
        1: _render_upload_step,
        2: _render_selection_step,
        3: _render_video_step,
    }

    with page_scaffold(page_name="test_character_consistency"):  # pylint: disable=not-context-manager
        with page_frame():  # pylint: disable=not-context-manager
            header(
                "Character Consistency Test",
                "person",
                show_info_button=True,
                on_info_click=open_info_dialog,
            )

            stepper(
                steps=["Upload & Prompt", "Select Image", "Create Video"],
                current_step=state.current_step,
                max_completed_step=state.max_completed_step,
                on_change=on_step_change,
            )

            with me.box(style=me.Style(margin=me.Margin(top=24))):
                render_step = step_renderers.get(state.current_step)
                if render_step is not None:
                    render_step(state)

            me.text(state.status_message, style=me.Style(margin=me.Margin(top=24)))


def open_info_dialog(e: me.ClickEvent):
    """Show the "About" dialog (wired to the header's info button)."""
    page_state = me.state(PageState)
    page_state.info_dialog_open = True
    yield


def close_info_dialog(e: me.ClickEvent):
    """Hide the "About" dialog (wired to the dialog's Close button)."""
    page_state = me.state(PageState)
    page_state.info_dialog_open = False
    yield


def on_upload(e: me.UploadEvent):
    """Store each uploaded file and register it as a reference image.

    Every file in the event is written via ``store_to_gcs`` under the
    "character_consistency_references" folder; the returned location and its
    display URL are appended to the page state.
    """
    page_state = me.state(PageState)
    for uploaded in e.files:
        stored_location = store_to_gcs(
            "character_consistency_references",
            uploaded.name,
            uploaded.mime_type,
            uploaded.getvalue(),
        )
        page_state.uploaded_image_gcs_uris.append(stored_location)
        preview_url = create_display_url(stored_location)
        page_state.uploaded_image_display_urls.append(preview_url)
    yield


def on_library_select(e: LibrarySelectionChangeEvent):
    """Register the library item chosen by the user as a reference image."""
    page_state = me.state(PageState)
    chosen_uri = e.gcs_uri
    page_state.uploaded_image_gcs_uris.append(chosen_uri)
    preview_url = create_display_url(chosen_uri)
    page_state.uploaded_image_display_urls.append(preview_url)
    yield


def on_prompt_input(e: me.InputEvent):
    """Mirror the Step 1 scene prompt textarea into the page state."""
    page_state = me.state(PageState)
    page_state.scene_prompt = e.value


def on_video_prompt_blur(e: me.InputBlurEvent):
    """Save the Step 3 video prompt when the textarea loses focus."""
    page_state = me.state(PageState)
    page_state.video_prompt = e.value
    yield


def on_modify_prompt_click(modifier: str):
    """Set the video prompt to the scene prompt followed by ``modifier``.

    Args:
        modifier: Preset phrase appended after the Step 1 scene prompt.
    """
    page_state = me.state(PageState)
    # NOTE(review): the base is the Step 1 scene prompt, so any manual edits
    # to the video prompt are replaced — confirm this is the intended UX.
    base_prompt = page_state.scene_prompt
    page_state.video_prompt = f"{base_prompt} {modifier}"
    yield


def next_step(e: me.ClickEvent):
    """Advance the wizard by one step and record the progress.

    On arriving at Step 3 the video prompt is seeded from the scene prompt.
    """
    page_state = me.state(PageState)
    upcoming = page_state.current_step + 1
    page_state.current_step = upcoming

    # Track the furthest step reached so the stepper can allow jumping back.
    if page_state.max_completed_step < upcoming:
        page_state.max_completed_step = upcoming

    page_state.status_message = f"Moved to step {upcoming}"
    if upcoming == 3:
        page_state.video_prompt = page_state.scene_prompt
    yield


def on_step_change(step_index: int):
    """Jump to the step clicked in the stepper, if it has been reached.

    Args:
        step_index: Zero-based index of the requested step.
    """
    page_state = me.state(PageState)
    requested_step = step_index + 1  # page state uses 1-based step numbers
    if requested_step <= page_state.max_completed_step:
        page_state.current_step = requested_step
    yield


def on_select_image_click(e: me.ClickEvent):
    """Record the clicked candidate image as the user's selection.

    The clicked box's key is the candidate's display URL.
    """
    page_state = me.state(PageState)
    page_state.user_selected_image_url = e.key
    yield


def generate_alternatives(e: me.ClickEvent):
    """Run the generation workflow until a recommended image is available.

    Streams progress messages to the status line, copies the character
    description, candidate images and recommended image into the page state,
    and then moves the wizard to Step 2. Consumption of the workflow stops as
    soon as the recommended image is reported. On any exception the error is
    logged, shown in the status line, and the wizard stays on its current step.
    """
    page_state = me.state(PageState)
    session = me.state(AppState)
    page_state.is_generating = True
    yield

    try:
        workflow = generate_character_video(
            user_email=session.user_email,
            reference_image_gcs_uris=page_state.uploaded_image_gcs_uris,
            scene_prompt=page_state.scene_prompt,
        )
        for progress in workflow:
            page_state.status_message = progress.message
            yield

            if "character_description" in progress.data:
                page_state.character_description = progress.data["character_description"]
            if "candidate_image_gcs_uris" in progress.data:
                candidate_uris = progress.data["candidate_image_gcs_uris"]
                page_state.candidate_image_gcs_uris = candidate_uris
                page_state.candidate_image_urls = [
                    create_display_url(candidate) for candidate in candidate_uris
                ]
            if "best_image_gcs_uri" in progress.data:
                recommended_uri = progress.data["best_image_gcs_uri"]
                page_state.best_image_gcs_uri = recommended_uri
                page_state.best_image_url = create_display_url(recommended_uri)
                # The remaining workflow stages are not needed for image selection.
                break
    except Exception as err:
        logger.exception("Error during character consistency generation")
        page_state.status_message = f"Error: {err!s}"
        page_state.is_generating = False
        yield
    else:
        page_state.is_generating = False
        page_state.status_message = "Generated alternatives."
        page_state.current_step = 2
        if page_state.max_completed_step < page_state.current_step:
            page_state.max_completed_step = page_state.current_step
        yield


def generate_video(e: me.ClickEvent):
    """Generate the final video from the image selected in Step 2.

    Converts the selected image's display URL back to a GCS URI, runs the
    generation workflow with the Step 3 video prompt, and stores the resulting
    video's GCS URI and display URL in the page state. Progress messages from
    the workflow are shown in the status line.

    If no image has been selected, nothing is generated and the status line
    asks the user to select one. If the workflow raises, the error is logged
    and shown in the status line, and ``is_generating`` is reset so the page
    does not stay stuck in the "generating" state.
    """
    state = me.state(PageState)
    app_state = me.state(AppState)

    # Step 2's "Continue" button does not require a selection, so guard here
    # rather than handing an empty URL to the model.
    if not state.user_selected_image_url:
        state.status_message = "Please select an image in Step 2 before generating a video."
        yield
        return

    state.is_generating = True
    yield

    video_ready = False
    try:
        # Convert the display URL back to a GCS URI before passing it to the model.
        gcs_uri = https_url_to_gcs_uri(state.user_selected_image_url)

        for step_result in generate_character_video(
            user_email=app_state.user_email,
            reference_image_gcs_uris=[gcs_uri],
            scene_prompt=state.video_prompt,
        ):
            # Surface workflow progress, mirroring generate_alternatives.
            state.status_message = step_result.message
            yield

            if "video_gcs_uri" in step_result.data:
                video_gcs_uri = step_result.data["video_gcs_uri"]
                state.final_video_gcs_uri = video_gcs_uri
                state.final_video_url = create_display_url(video_gcs_uri)
                video_ready = True
                break
    except Exception as err:  # top-level UI boundary: report instead of crashing the handler
        logger.exception("Error during character consistency video generation")
        state.status_message = f"Error: {err}"
        state.is_generating = False
        yield
        return

    state.is_generating = False
    # Only claim success when the workflow actually reported a video.
    if video_ready:
        state.status_message = "Generated video."
    else:
        state.status_message = "Video generation finished without producing a video."
    yield


@me.page(path="/test_character_consistency")
def page():
    """Mesop page entry point registered at /test_character_consistency."""
    character_consistency_page_content()
